训练机器翻译

Note

本节实现一个可以训练机器翻译模型的函数。

损失函数

每个时间步,解码器预测了输出词元的概率分布。类似于语言模型,可以使用softmax获得分布,并通过计算交叉熵损失来进行优化。

但是填充词元应该被排除在损失函数的计算之外。

import collections
import math

import torch
from torch import nn
import d2l


#@save
def sequence_mask(X, valid_len, value=0):
    """Overwrite the entries of `X` that lie past each row's valid length.

    `X` is modified in place and also returned.

    Args:
        X: tensor whose first two axes are (`batch_size`, `num_steps`).
        valid_len: tensor of shape (`batch_size`,) giving, per row, how many
            leading steps are kept.
        value: fill value written into the masked-out positions.
    """
    num_steps = X.size(1)
    # Step indices as a row vector: (`num_steps`,) -> (1, `num_steps`)
    step_idx = torch.arange(num_steps, dtype=torch.float32,
                            device=X.device).unsqueeze(0)
    # Valid lengths as a column vector: (`batch_size`,) -> (`batch_size`, 1)
    limits = valid_len.unsqueeze(1)
    # Broadcasting yields a (`batch_size`, `num_steps`) boolean "keep" mask
    keep = step_idx < limits
    X[~keep] = value
    return X
#@save
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """Softmax cross-entropy loss in which padded positions contribute zero.

    Expected shapes:
        pred: (`batch_size`, `num_steps`, `vocab_size`)
        label: (`batch_size`, `num_steps`)
        valid_len: (`batch_size`,)
    """

    def forward(self, pred, label, valid_len):
        """Return one loss value per sequence, shape (`batch_size`,).

        The per-token losses are zeroed at padded positions and then averaged
        over all `num_steps` positions (padding included in the denominator).
        """
        # 1 for real tokens, 0 for padding
        mask = sequence_mask(torch.ones_like(label), valid_len)
        # Keep the per-token losses so the mask can be applied before reducing
        self.reduction = 'none'
        # nn.CrossEntropyLoss expects the class axis second:
        # (`batch_size`, `vocab_size`, `num_steps`) vs (`batch_size`, `num_steps`)
        class_second = pred.permute(0, 2, 1)
        per_token_loss = super().forward(class_second, label)
        # Masked loss, averaged along the time axis
        return (per_token_loss * mask).mean(dim=1)

训练

训练时,解码器的输入不是采样自上一步的输出,而是 `<bos>` 加上真实的输出序列(去掉最后一个词元),这被称为教师强制(teacher forcing),它能使我们训练得更快。

#@save
def train_nmt(net, data_iter, lr, num_epochs, tgt_vocab):
    """Train a machine-translation model with teacher forcing.

    Args:
        net: model called as `net(X, dec_input, X_valid_len)` and returning a
            pair whose first element is the prediction (an encoder-decoder
            structure is expected).
        data_iter: iterable of batches, each unpacking into
            `(X, X_valid_len, Y, Y_valid_len)`.
        lr: learning rate for the Adam optimizer.
        num_epochs: number of passes over `data_iter`.
        tgt_vocab: target vocabulary; must map `'<bos>'` to its index.
    """
    device = d2l.try_gpu()
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss_fn = MaskedSoftmaxCELoss()
    # Training mode must be set explicitly because the model uses Dropout
    net.train()
    # Plots the masked cross-entropy loss
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs])
    for epoch in range(1, num_epochs + 1):
        # Two running totals: summed loss, number of target tokens
        metric = d2l.Accumulator(2)
        for batch in data_iter:
            X, X_valid_len, Y, Y_valid_len = (t.to(device) for t in batch)
            # Teacher forcing: decoder input is <bos> followed by the true
            # target sequence with its last token dropped
            bos_ids = [tgt_vocab['<bos>']] * Y.shape[0]
            bos = torch.tensor(bos_ids, device=device).reshape(-1, 1)
            dec_input = torch.cat((bos, Y[:, :-1]), dim=1)
            Y_hat, _ = net(X, dec_input, X_valid_len)

            # Backpropagation
            optimizer.zero_grad()
            loss = loss_fn(Y_hat, Y, Y_valid_len)
            loss.sum().backward()
            optimizer.step()
            # Bookkeeping only; no gradients needed
            with torch.no_grad():
                metric.add(loss.sum(), Y_valid_len.sum())
        # Add a point to the plot every 5 epochs
        if epoch % 5 == 0:
            animator.add(epoch, (metric[0] / metric[1],))

预测

预测时,我们没有真实的输出序列,解码器当前时间步的输入都将来自于前一时间步的输出词元。

jupyter

#@save
def predict_nmt(net, src_sentence, src_vocab, tgt_vocab, num_steps, device=d2l.try_gpu()):
    """Translate `src_sentence` by greedy decoding.

    At each step the token with the highest predicted score is emitted and fed
    back as the next decoder input. Decoding stops at `<eos>` or after
    `num_steps` steps.

    Args:
        net: model exposing `encoder` and `decoder` (with `init_state`).
        src_sentence: source sentence; it is lower-cased and split on single
            spaces.
        src_vocab: source vocabulary (token list -> index list; must contain
            `'<eos>'` and `'<pad>'`).
        tgt_vocab: target vocabulary (must contain `'<bos>'` and `'<eos>'` and
            provide `to_tokens`).
        num_steps: padded source length and maximum number of decoding steps.
        device: device on which the input tensors are created.

    Returns:
        The predicted target tokens joined by single spaces.
    """
    net.eval()
    # Prepare the source sentence
    src_tokens = src_vocab[src_sentence.lower().split(' ')] + [src_vocab['<eos>']]
    # NOTE(review): the valid length is taken before truncate_pad, so it can
    # exceed num_steps for long sentences -- confirm the encoder tolerates that.
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    src_tokens = d2l.truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
    enc_X = torch.unsqueeze(
        torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)
    eos_id = tgt_vocab['<eos>']
    output_seq = []
    # Inference only: disabling autograd avoids building a graph that would
    # otherwise grow through `dec_state` at every decoding step
    with torch.no_grad():
        enc_outputs = net.encoder(enc_X, enc_valid_len)
        # Initial decoder state and initial input (<bos>)
        dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
        dec_X = torch.unsqueeze(
            torch.tensor([tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)
        # Decode one step at a time
        for _ in range(num_steps):
            Y, dec_state = net.decoder(dec_X, dec_state)
            # The highest-scoring token becomes the next decoder input
            dec_X = Y.argmax(dim=2)
            pred = dec_X.squeeze(dim=0).type(torch.int32).item()
            # Once the end-of-sequence token is predicted, generation is done
            if pred == eos_id:
                break
            output_seq.append(pred)
    return ' '.join(tgt_vocab.to_tokens(output_seq))

评估

我们可以通过与真实标签序列做比较来评估预测序列。

\(p_{n}\) 表示 \(n\)元语法的精确度,它是两个数量的比值,分子是预测序列与标签序列中匹配的 \(n\)元语法的数量,分母是预测序列中 \(n\)元语法的数量。

那么, BLEU 的定义是:

\[\exp\left(\min\left(0, 1 - \frac{\mathrm{len}_{\text{label}}}{\mathrm{len}_{\text{pred}}}\right)\right)\prod_{n=1}^{k}p_{n}^{1/2^{n}}\]

其中 \(k\) 是用于匹配的最长 \(n\)元语法的长度,指数项用于惩罚较短的预测序列。

#@save
def bleu(pred_seq, label_seq, k):
    """Compute the BLEU score of `pred_seq` against `label_seq`.

    The score is ``exp(min(0, 1 - len_label / len_pred))`` (a penalty for
    predictions shorter than the label) multiplied by ``p_n ** (1 / 2**n)``
    for ``n = 1..k``, where ``p_n`` is the clipped n-gram precision.

    Args:
        pred_seq: predicted sentence, tokens separated by single spaces.
        label_seq: reference sentence, tokens separated by single spaces.
        k: length of the longest n-gram used for matching.

    Returns:
        The score as a float. It is 0.0 when the prediction is too short to
        contain any n-gram of some order ``n <= k``.
    """
    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))
    # Accumulate the n-gram precisions
    for n in range(1, k + 1):
        num_pred_ngrams = len_pred - n + 1
        if num_pred_ngrams <= 0:
            # No n-grams of this order in the prediction: precision is zero
            # (and dividing by the count would raise ZeroDivisionError)
            return 0.0
        # Count the label's n-grams. Keys are token tuples so that distinct
        # n-grams can never collide (("ab", "c") vs ("a", "bc")).
        label_subs = collections.Counter(
            tuple(label_tokens[i:i + n]) for i in range(len_label - n + 1))
        # Clipped matching: each label n-gram can be matched at most once
        num_matches = 0
        for i in range(num_pred_ngrams):
            ngram = tuple(pred_tokens[i:i + n])
            if label_subs[ngram] > 0:
                num_matches += 1
                label_subs[ngram] -= 1
        score *= math.pow(num_matches / num_pred_ngrams, math.pow(0.5, n))
    return score